# ------------------------------------------------------------------------------
# Plots OOS training-set outcome rates (MACE, yield, testing) by stent-hat tile
# Updated by: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=25000]" bash 05_examine_outcomes.sh
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Pin the RNG state so any random draws made later in the run are reproducible.
set.seed(seed = 1L)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(Matrix)
library(glmnet)
library(xgboost)
library(optparse)
library(testit) # assert
library(here) # here() relative filepaths

u <- modules::use(here::here("lib", "util.R"))
a <- modules::use(here::here("lib", "aesthetics.R"))

# Helpers ----------------------------------------------------------------------
# Assign each element of `x` to one of `n_tiles` quantile bins, where the
# cutpoints are computed only from the reference subset `x[within]`.
#
# Args:
#   x: numeric vector to bin.
#   n_tiles: number of bins (a single number >= 1).
#   within: index into `x` (logical mask or integer positions) selecting the
#     reference subset used to compute cutpoints; defaults to all of `x`.
#
# Returns an integer vector the same length as `x`: element i is 1 plus the
# number of cutpoints strictly below x[i], so values lie in 1..n_tiles.
# NA elements of `x` get NA. Missing values in the reference subset are
# ignored; if it has no non-missing values, every tile is NA.
ntile_within <- function(x, n_tiles, within = seq_along(x)) {
  stopifnot(length(n_tiles) == 1L, n_tiles >= 1)

  # seq_len() yields no interior probabilities when n_tiles == 1; the previous
  # seq(1, n_tiles - 1) produced c(1, 0) and hence spurious cutpoints.
  probs <- seq_len(n_tiles - 1L) / n_tiles
  cutpoints <- quantile(x[within], probs, na.rm = TRUE, names = FALSE)

  # Only possible when the reference subset has no usable values.
  if (anyNA(cutpoints)) {
    return(rep(NA_integer_, length(x)))
  }

  # With left.open = TRUE, findInterval() counts the cutpoints strictly less
  # than each x -- the same `x > cutpoint` rule, without an n x k matrix.
  1L + findInterval(x, cutpoints, left.open = TRUE)
}

# Map an outcome column name to the patient population it is evaluated on.
#
# Args:
#   outcome: outcome column name(s), e.g. "mace_030_day".
#   ...: ignored; absorbs the other columns passed in by purrr::pmap().
#
# Returns a character vector the same length as `outcome`:
#   "untested" if the name contains "mace" (checked first, so it wins),
#   "tested"   if it contains "stent",
#   "all"      otherwise (including NA).
get_population <- function(outcome, ...) {
  population <- rep("all", length(outcome))
  population[grepl("stent", outcome)] <- "tested"
  # Assigned last so "mace" takes precedence over "stent".
  population[grepl("mace", outcome)] <- "untested"
  population
}

# Compute the mean of an outcome, with its standard error, within each value
# of a tile column for one population and one training fold.
#
# Args:
#   outcome: name of the outcome column to average (e.g. "mace_030_day").
#   stent: name of the tile column to group by (becomes `x_var`).
#   population: "untested" keeps rows with !test_010_day, "tested" keeps rows
#     with test_010_day, any other value keeps all rows.
#   fold: value of `train_fold` to keep.
#   ...: ignored; absorbs the other columns passed in by purrr::pmap().
#   data: patient-level table; defaults to the script-level `ids`.
#
# Returns an ungrouped tibble with columns x_var, grouping (the population
# label), outcome_rate and std.error (sd / sqrt(n) within each tile).
get_outcome_means <- function(outcome, stent, population, fold, ...,
                              data = ids) {
  message(glue(
    "Outcome means: {outcome} by {stent} ({population}, fold {fold})"
  ))

  df <- switch(population,
    untested = filter(data, !test_010_day),
    tested = filter(data, test_010_day),
    data
  )

  df %>%
    # .env pins these to the function arguments even if `data` has columns
    # named `fold` or `population`.
    filter(train_fold == .env$fold) %>%
    mutate(grouping = .env$population) %>%
    # Naming the grouping column directly replaces the old
    # group_by(eval(as.name(stent))) + setnames() rename, which renamed the
    # column by reference but left the grouped tibble's group metadata
    # pointing at the stale auto-generated name.
    group_by(x_var = .data[[stent]], grouping) %>%
    summarize(
      outcome_rate = mean(.data[[outcome]]),
      n = n(),
      std.deviation = sd(.data[[outcome]]),
      std.error = std.deviation / sqrt(n),
      .groups = "drop"
    ) %>%
    select(-n, -std.deviation)
}

# Attach axis labels, title, subtitle and an N caption to a tile plot.
#
# Args:
#   gg: a ggplot object (as produced by the plotting step of the pipeline).
#   outcome: outcome column name; mapped to a readable y-axis label.
#   stent: tile column name; mapped to a readable x-axis label.
#   population: "untested" / "tested" / anything else (all rows); used for the
#     subtitle and to count N.
#   fold: `train_fold` value; N counts only rows in this fold.
#   ...: ignored; absorbs the other columns passed in by purrr::pmap().
#   data: patient-level table used for N; defaults to the script-level `ids`.
#
# Returns `gg` with the labels added.
add_labels <- function(gg, outcome, stent, population, fold, ...,
                       data = ids) {
  # Covers every tile column the script builds; unknown names pass through.
  xlabel <- case_when(
    stent == "tile_stent_or_cabg_005_tested" ~ "Stent-Hat Quintile",
    stent == "tile_stent_or_cabg_010_tested" ~ "Stent-Hat Decile",
    stent == "tile_stent_or_cabg_050_tested" ~ "Stent-Hat 50-Tile",
    stent == "tile_stent_or_cabg_100_tested" ~ "Stent-Hat Percentile",
    TRUE ~ stent
  )
  ylabel <- case_when(
    outcome == "mace_030_day" ~ "MACE Rate (30 day)",
    outcome == "stent_or_cabg_010_day" ~ "Yield Rate (10 day)",
    outcome == "test_010_day" ~ "Testing Rate (10 day)",
    TRUE ~ outcome
  )

  # Same population/fold selection as the means computation, so N matches
  # the rows behind the plotted points.
  df <- switch(population,
    untested = filter(data, !test_010_day),
    tested = filter(data, test_010_day),
    data
  )
  n <- nrow(filter(df, train_fold == .env$fold))

  title <- glue("{ylabel} by {xlabel} for OOS Training Predictions")
  subtitle <- glue("Training cohort ({population}), excluding patients with chronic illness history, untested AMI day of")
  caption <- glue("N = {n}")

  gg + labs(
    x = xlabel, y = ylabel, title = title,
    subtitle = subtitle, caption = caption
  )
}

# Build the PNG output path for one outcome / tile-column / fold combination.
#
# Args:
#   outcome: outcome column name.
#   stent: tile column name.
#   fold: training fold.
#   ...: ignored; absorbs the other columns passed in by purrr::pmap().
#   dir: output directory; defaults to the script-level `temp`.
#
# Returns a path of the form "<dir>/<outcome>_by_<stent>_fold<fold>.png".
get_filename <- function(outcome, stent, fold, ..., dir = temp) {
  file.path(dir, sprintf("%s_by_%s_fold%s.png", outcome, stent, fold))
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- read_yaml(here::here("lib", "filepaths.yml"))
# Directory the figures are written to (see get_filename()).
temp <- here::here(
  "code", "03_analysis", "01_build_models", "08_oos_train_yhats", "temp"
)

# Out-of-sample prediction scores for the training cohort.
scores <- readRDS(
  file.path(paths$modeling$oos, "prediction", "scores.rds")
)

# The cohort's own `train_fold` is dropped before the join, so the fold
# column used later presumably comes from `scores` -- TODO confirm.
ids <- readRDS(paths$modeling$train)
ids <- select(ids, -train_fold)
ids <- u$safe_left_join(ids, scores)
ids <- mutate(ids, stent_hat = p__ensemble__stent_or_cabg_010_day__tested)
ids <- setDT(ids)
# scores <- readRDS(file.path(paths$modeling$oos, "prediction",
#                             "overnight", "scores.rds"))
# ids <- readRDS(paths$modeling$train_overnight) %>%
#        select(-train_fold) %>%
#        u$safe_left_join(scores) %>%
#        mutate(stent_hat = p__ensemble__stent_or_cabg_010_day__tested) %>%
#        setDT

# Prep Data --------------------------------------------------------------------
message("Getting stent tiles...")
# Every tile column takes its cutpoints from the same reference group: tested
# patients with none of the c_int / death / chronic exclusion flags set.
# Defined once so the criteria cannot drift between the tile columns.
tile_reference <- ids[, test_010_day == TRUE & excl_flag_c_int == FALSE &
  excl_flag_death == FALSE & excl_flag_chronic == FALSE]

# Adds tile_stent_or_cabg_{005,010,050,100}_tested to `ids` by reference.
for (tile_count in c(5L, 10L, 50L, 100L)) {
  tile_col <- sprintf("tile_stent_or_cabg_%03d_tested", tile_count)
  ids[, (tile_col) := ntile_within(stent_hat, tile_count, tile_reference)]
}

# Outcome Plots ----------------------------------------------------------------
# Keep rows where not_ami_day_of is TRUE or the patient was tested within 10
# days, and drop any row with a c_int, death or chronic exclusion flag.
ids <- filter(
  ids,
  not_ami_day_of | test_010_day,
  !(excl_flag_c_int | excl_flag_death | excl_flag_chronic)
)

message("Making plots...")
# One row per (outcome, tile column, fold) figure.
config <- crossing(
  outcome = c("mace_030_day", "stent_or_cabg_010_day", "test_010_day"),
  stent = "tile_stent_or_cabg_010_tested",
  fold = seq(1, 5)
)

# Each step passes every column built so far to the helper via pmap(); the
# helpers absorb the columns they do not use through `...`. Scalar-valued
# steps use pmap_chr() so a bad return type fails loudly instead of silently
# producing a list column.
plots <- config %>%
  mutate(population = pmap_chr(., get_population)) %>%
  mutate(mean_outcomes = pmap(., get_outcome_means)) %>%
  mutate(gg = pmap(., a$tile_plot)) %>%
  mutate(plot = pmap(., add_labels)) %>%
  mutate(filename = pmap_chr(., get_filename))

# Save -------------------------------------------------------------------------
message("Saving...")
# ggsave() errors if the target directory is missing, so create it first.
# unlist() handles `filename` whether it is a character or a list column.
walk(
  unique(dirname(unlist(plots$filename))),
  dir.create,
  showWarnings = FALSE, recursive = TRUE
)

# pwalk() is used for its side effect only; pmap() at top level would return
# a list that Rscript auto-prints.
plots %>%
  select(plot, filename) %>%
  pwalk(ggsave, width = 10, height = 7, units = "in")

# ------------------------------------------------------------------------------
message("Done.")
